import networkx as nx
import argparse

class DFAGenerator:
    """Build path DFAs over a network topology read from a text file.

    The topology file holds one link per line, as whitespace-separated fields::

        <node_a> <port_a> <node_b> <port_b>

    Each link becomes an edge of an undirected ``networkx.MultiGraph`` whose
    ``portmap`` attribute maps each endpoint name to its port on that link.
    Parallel links between the same pair of nodes are kept as separate edges.
    """

    def __init__(self, topo_file, output):
        """Load the topology from *topo_file*.

        Args:
            topo_file: Path to the topology file (format in the class docstring).
            output: Path that :meth:`ecmp` writes its DFA edge list to.

        Raises:
            ValueError: If a non-blank line has fewer than four fields.
        """
        self.output = output
        self.topo = nx.MultiGraph()
        # Use a context manager so the file handle is closed deterministically.
        with open(topo_file) as f:
            for lineno, line in enumerate(f, start=1):
                link = self._parse_link(line, lineno)
                if link is None:
                    continue  # blank line
                node_a, node_b, portmap = link
                self.topo.add_edge(node_a, node_b, portmap=portmap)

    @staticmethod
    def _parse_link(line, lineno=None):
        """Parse one topology line into ``(node_a, node_b, portmap)``.

        Returns None for a blank line. Fields beyond the fourth are ignored.
        If both endpoints are the same node, the portmap keeps only
        ``port_b`` (later dict key wins), exactly as before.

        Raises:
            ValueError: If the line is non-blank but has fewer than four fields.
        """
        fields = line.split()
        if not fields:
            return None
        if len(fields) < 4:
            where = '' if lineno is None else ' (line %d)' % lineno
            raise ValueError('malformed topology line%s: %r' % (where, line))
        node_a, port_a, node_b, port_b = fields[:4]
        return node_a, node_b, {node_a: port_a, node_b: port_b}

    @staticmethod
    def _path_edges(path):
        """Return the consecutive hops ``(n0, n1), (n1, n2), ...`` of a node path."""
        return list(zip(path, path[1:]))

    @classmethod
    def _dfa_from_paths(cls, paths):
        """Return a DiGraph containing every hop of every node path in *paths*."""
        dfa = nx.DiGraph()
        for path in paths:
            # Adding an already-present edge to a DiGraph is a no-op, so hops
            # shared by several paths are merged without an explicit check.
            dfa.add_edges_from(cls._path_edges(path))
        return dfa

    def simple_path(self, src, dst, length=None):
        """Build one DFA per distinct simple path from *src* to *dst*.

        Prints the distinct paths, their count, and each resulting DFA.

        Args:
            src: Source node name.
            dst: Destination node name.
            length: Optional cutoff forwarded as ``cutoff`` to
                ``networkx.all_simple_paths`` (only paths of at most this
                many edges are enumerated); None means no cutoff.

        Returns:
            A list of ``networkx.DiGraph`` objects, one per distinct node
            path, each containing that path's directed hops.
        """
        paths = nx.all_simple_paths(self.topo, src, dst, cutoff=length)
        # On a MultiGraph the same node sequence can be produced more than
        # once (e.g. via parallel links), so deduplicate. Tuples are used
        # because lists are unhashable. Sorting gives a stable order
        # regardless of string-hash randomization.
        unique_paths = sorted({tuple(path) for path in paths})
        print(unique_paths)
        print(len(unique_paths))
        dfas = [self._dfa_from_paths([path]) for path in unique_paths]
        for dfa in dfas:
            print(dfa)
        return dfas

    def ecmp(self, src, dst):
        """Write the union of all shortest *src* -> *dst* paths as a DFA edge list.

        Prints the number of equal-cost shortest paths and the number of
        DFA edges, then writes one ``"<from> <to>"`` line per directed edge
        to ``self.output`` (overwriting it). Errors raised by
        ``networkx.all_shortest_paths`` (e.g. when no path exists) propagate.
        """
        paths = list(nx.all_shortest_paths(self.topo, src, dst))
        print('ecmp paths: %d' % len(paths))
        dfa = self._dfa_from_paths(paths)
        print('DFA edges: %d' % len(dfa.edges))
        with open(self.output, 'w') as f:
            f.writelines('%s %s\n' % (u, v) for u, v in dfa.edges)

if __name__ == '__main__':
    # Command-line entry point: load a topology and write the ECMP DFA
    # (union of all shortest paths) for a single src/dst pair.
    cli = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ("topo", "the topo file"),
        ("src", "the src node"),
        ("dst", "the dst node"),
        ("output", "the output dfa file"),
    ):
        cli.add_argument(arg_name, help=arg_help)

    opts = cli.parse_args()
    DFAGenerator(opts.topo, opts.output).ecmp(opts.src, opts.dst)